package fr.lip6.meta.ple.configsgenerator;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Set;

import org.sat4j.minisat.SolverFactory;
import org.sat4j.reader.InstanceReader;
import org.sat4j.reader.ParseFormatException;
import org.sat4j.reader.Reader;
import org.sat4j.specs.ContradictionException;
import org.sat4j.specs.IProblem;
import org.sat4j.specs.ISolver;
import org.sat4j.specs.TimeoutException;
import org.sat4j.tools.ModelIterator;

import guidsl.*;

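/**
 * Automatically generates configuration files from a GUIDSL feature model.
 * Command-line inputs: the model file (e.g. {@code model.m}) and, optionally,
 * a feature-order file listing one feature name per line.
 * @author Luz
 */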
public class ConfigsGenerator {

	/** Solver timeout in seconds used by {@link #solve(String)} (1 hour). */
	private static final int DEFAULT_TIMEOUT_SECONDS = 3600;

	/** Directory into which {@link #main(String[])} writes the generated config files. */
	private static final String OUTPUT_DIR = "generatedConfs";

	/** Ordering applied to solution literals; {@code null} means natural integer order. */
	private static FeatureComparator comparator = null;

	/** Primitive variables that are not hidden; their selection identifies a configuration. */
	private static ArrayList<variable> userVars = new ArrayList<variable>();

	/**
	 * Best solution found so far for each distinct user selection, keyed by the
	 * {@code toString()} of the selected user literals. "Best" means the solution
	 * with the fewest positive literals.
	 */
	private static HashMap<String,Integer[]> alreadyGenerated = new HashMap<String,Integer[]>();

	/**
	 * Converts a GUIDSL model to CNF and writes it to {@code output}, dropping
	 * lines that start with {@code "c "} (DIMACS comment lines). The CNF text is
	 * also echoed to stdout and the variable table is dumped.
	 * I/O errors are reported but not propagated to the caller.
	 *
	 * @param model  path of the model file handed to {@code guidsl.Tool}
	 * @param output path of the CNF file to (over)write
	 */
	public static void toCnfFile(String model, String output) {
		Tool t = new Tool(model);
		cnfModel m = t.getCnfModel();
		String cnf = m.model();
		System.out.println(cnf);
		variable.dumpVtable();

		BufferedReader br = new BufferedReader(new StringReader(cnf));
		PrintWriter pw = null;
		try {
			pw = new PrintWriter(new FileWriter(output));
			String line;
			while ((line = br.readLine()) != null) {
				if (!line.startsWith("c "))
					pw.println(line);
			}
			// PrintWriter never throws on write; it only records an error flag.
			if (pw.checkError())
				System.err.println("error while writing " + output);
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (pw != null)
				pw.close(); // also closes the underlying FileWriter
		}
	}

	/** Boxes an {@code int[]} model into an {@code Integer[]} so it can be sorted with a Comparator. */
	private static Integer[] cast(int[] a) {
		Integer[] v = new Integer[a.length];
		for (int i = 0; i < a.length; i++)
			v[i] = a[i];
		return v;
	}

	/**
	 * Enumerates every model of the CNF stored in {@code cnfFile}, using a
	 * one-hour solver timeout, and records each one via {@link #propagate(Integer[])}.
	 *
	 * @param cnfFile path of a CNF file readable by SAT4J's {@code InstanceReader}
	 */
	public static void solve(String cnfFile) {
		solve(cnfFile, DEFAULT_TIMEOUT_SECONDS);
	}

	/**
	 * Enumerates every model of the CNF stored in {@code cnfFile} and records
	 * each one via {@link #propagate(Integer[])}. Prints "Unsat model" when no
	 * model exists. Parse/I-O errors and timeouts are reported, not rethrown.
	 *
	 * @param cnfFile        path of a CNF file readable by SAT4J's {@code InstanceReader}
	 * @param timeoutSeconds value handed to {@code ISolver.setTimeout}
	 */
	public static void solve(String cnfFile, int timeoutSeconds) {
		ISolver solver = SolverFactory.newDefault();
		// ModelIterator excludes each model once found, so the loop below terminates.
		ModelIterator mi = new ModelIterator(solver);
		solver.setTimeout(timeoutSeconds);
		Reader reader = new InstanceReader(mi);

		try {
			IProblem problem = reader.parseInstance(cnfFile);
			boolean unsat = true;
			while (problem.isSatisfiable()) {
				unsat = false;
				propagate(cast(problem.model()));
			}
			if (unsat)
				System.out.println("Unsat model");
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (ParseFormatException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		} catch (ContradictionException e) {
			System.out.println("Unsatisfiable (trivial)!");
		} catch (TimeoutException e) {
			System.out.println("Timeout, sorry!");
		}
	}

	/**
	 * Writes one config file: the names of the selected (positive-literal)
	 * variables whose type is {@code variable.Prim}, one per line, in the order
	 * given by {@link #comparator} (ascending index when none is set).
	 * Note: {@code vars} is sorted in place.
	 *
	 * @param vars     solution literals (positive = selected)
	 * @param fileName path of the config file to (over)write
	 * @throws IOException if the file cannot be opened or written
	 */
	public static void toConfigFile(Integer vars[], String fileName) throws IOException {
		if (comparator == null)
			Arrays.sort(vars);
		else
			Arrays.sort(vars, comparator);

		PrintWriter pw = new PrintWriter(new FileWriter(fileName));
		try {
			for (int i : vars) {
				if (i > 0) { // positive values == selected features
					String name = variable.findVar(i);
					variable v = variable.find(name);
					if (v.type == variable.Prim)
						pw.println(name);
				}
			}
			// PrintWriter swallows IOExceptions; surface them to the caller.
			if (pw.checkError())
				throw new IOException("error while writing " + fileName);
		} finally {
			pw.close(); // runs even if a lookup above throws, so the file is never leaked
		}
	}

	/**
	 * Records one solver model. The model is keyed by the selected literals that
	 * belong to {@link #userVars}. For each key, the model with the fewest
	 * positive literals is kept. Note: {@code solution} is sorted in place.
	 */
	private static void propagate(Integer[] solution) {
		if (comparator == null)
			Arrays.sort(solution);
		else
			Arrays.sort(solution, comparator);

		ArrayList<Integer> userSelec = new ArrayList<Integer>();
		for (int i : solution) {
			if (i > 0) {
				variable v = variable.find(variable.findVar(i));
				if (userVars.contains(v))
					userSelec.add(i);
			}
		}

		System.out.print("user selected: ");
		for (int i : userSelec)
			System.out.print(variable.findVar(i) + " ");
		System.out.println();

		String key = userSelec.toString();
		Integer[] old = alreadyGenerated.get(key);
		if (old == null) {
			alreadyGenerated.put(key, solution);
		} else {
			int oldCount = countPositive(old);
			int newCount = countPositive(solution);
			System.out.println("old: " + oldCount + " new: " + newCount);
			if (oldCount > newCount)
				alreadyGenerated.put(key, solution); // put() replaces the previous entry
		}
	}

	/** Returns how many entries of {@code a} are strictly positive (i.e. selected). */
	private static int countPositive(Integer[] a) {
		int s = 0;
		for (int i : a)
			if (i > 0) s++;
		return s;
	}

	/**
	 * Builds {@link #comparator} from an order file listing one feature name per
	 * line. Earlier lines get lower ranks, starting at 1. Surrounding whitespace
	 * is ignored and blank lines are skipped. If a name is repeated, its last
	 * occurrence wins. Variables not listed get rank 0, so they sort before all
	 * listed ones.
	 *
	 * @param fileName path of the order file
	 * @throws IOException if the file cannot be read
	 */
	public static void initComparator(String fileName) throws IOException {
		HashMap<String,Integer> names = new HashMap<String,Integer>();
		BufferedReader br = new BufferedReader(new FileReader(fileName));
		try {
			int next = 1;
			String line;
			while ((line = br.readLine()) != null) {
				line = line.trim();
				if (line.length() > 0)
					names.put(line, next++);
			}
		} finally {
			br.close();
		}

		HashMap<Integer,Integer> order = new HashMap<Integer,Integer>();
		for (int i = 1; i <= variable.vtsize; i++) {
			Integer rank = names.get(variable.findVar(i));
			if (rank == null) rank = 0;
			order.put(i, rank);
			System.out.println(variable.findVar(i) + ": " + rank);
		}

		comparator = new FeatureComparator(order);
	}

	/**
	 * Entry point. Arguments:
	 * {@code <model.m> [feature_order]}. Converts the model to
	 * {@code constraints.cnf}, enumerates its models and writes one config per
	 * distinct user selection to {@code generatedConfs/cN.config}.
	 */
	public static void main (String args[]) throws IOException {
		if (args.length < 1) {
			System.err.println("usage: ConfigsGenerator <model.m> [feature_order]");
			return;
		}
		String cnfFile = "constraints.cnf";

		toCnfFile(args[0], cnfFile);
		initUserVars();
		if (args.length > 1) {
			System.out.println("using order specification");
			initComparator(args[1]);
		}

		solve(cnfFile);

		if (!alreadyGenerated.isEmpty()) {
			// FileWriter does not create missing parent directories.
			File outDir = new File(OUTPUT_DIR);
			if (!outDir.isDirectory() && !outDir.mkdirs())
				throw new IOException("cannot create output directory " + outDir.getPath());
		}

		int n = 0;
		for (Integer[] sol : alreadyGenerated.values())
			toConfigFile(sol, OUTPUT_DIR + "/c" + (++n) + ".config");
		System.out.println("Generated " + n + " config files");
	}

	/**
	 * Marks as hidden every primitive reachable from production {@code p}.
	 * {@code visited} (identity-based) stops re-traversal of shared sub-grammars
	 * and prevents infinite recursion on recursive productions.
	 * A {@code null} production is ignored.
	 */
	@SuppressWarnings("unchecked")
	private static void inheritHidden(production p, Set<Object> visited) {
		if (p == null || !visited.add(p))
			return;
		for (pattern pat : (ArrayList<pattern>) p.pat)
			inheritHidden(pat, visited);
	}

	/**
	 * Marks as hidden every primitive term of pattern {@code p}, recursing into
	 * referenced productions. Terms of any other kind are only logged.
	 */
	private static void inheritHidden(pattern p, Set<Object> visited) {
		if (!visited.add(p))
			return;
		for (Object t : p.terms) {
			if (t instanceof prod) {
				prod pr = (prod) t;
				inheritHidden(pr.findProduction(pr.name), visited);
			} else if (t instanceof prim) {
				((prim) t).var.hidden = true;
			} else {
				System.out.println("soy " + t.getClass().getName());
			}
		}
	}

	/**
	 * Rebuilds {@link #userVars}. First, the hidden flag is pushed from hidden
	 * productions and patterns down to the primitives they contain. Then every
	 * non-hidden primitive variable is collected.
	 */
	private static void initUserVars() {
		userVars.clear();
		// Marking is idempotent, so one visited set can be shared by all hidden roots.
		Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<Object, Boolean>());

		// Children of hidden productions become hidden as well.
		System.out.println("\nproductions");
		Iterator<production> prods = production.Ptable.values().iterator();
		while (prods.hasNext()) {
			production p = (production) prods.next();
			if (p.var.hidden)
				inheritHidden(p, visited);
		}

		// Same for hidden patterns.
		System.out.println("\npatterns");
		Iterator<pattern> pats = pattern.Ttable.values().iterator();
		while (pats.hasNext()) {
			pattern p = (pattern) pats.next();
			if (p.var.hidden)
				inheritHidden(p, visited);
		}

		// Keep the non-hidden primitives from the global variable table.
		Iterator<variable> vars = variable.Vtable.values().iterator();
		while (vars.hasNext()) {
			variable v = (variable) vars.next();
			if (v.type == variable.Prim && !v.hidden)
				userVars.add(v);
		}
	}

}

class FeatureComparator implements Comparator<Integer> {
	
	HashMap<Integer, Integer> order = null;
	
	public FeatureComparator (HashMap<Integer, Integer> order) {
		this.order = order;
	}

	@Override
	public int compare(Integer n1, Integer n2) {
		if (n1 <= 0 || n2 <= 0) {
			if (n1 < n2) return -1;
			if (n2 < n1) return 1;
			return 0;
		}
		int o1 = order.get(n1);
		int o2 = order.get(n2);
		if (o1 < o2) return -1;
		if (o1 > o2) return 1;
		return 0;
	}
	
}
